# V Final exhibit - Figure 5: average monthly spending by hospital intensity and initial prognosis (phat) ------
#' Figure 5: average monthly spending by spending category and prognosis bin
#'
#' Sums each patient's cost on/after the index date (from `dt_cst`) by
#' top-level category (`top_main_cat`), gives every patient in `dt_exh` a
#' (zero-filled) row for every category, and plots average monthly spending
#' (total cost / total days lived * 31) against 0.1-wide bins of
#' `prob_for_report`. The plot has three panels: all patients, decedents
#' (`DMG_died_within_365d == "1"`) and survivors. Cells with fewer than
#' `min_obs_num` rows are dropped.
#'
#' Side effects: writes `Average_Monthly_Spending_by_Type_<hosp_a>_<hosp_b>.csv`
#' and `.pdf` to the working directory.
#'
#' Relies on names not passed as arguments: `min_obs_num`, `inplot_text_size`
#' and, for the defaults, `dt_for_exhibits_cancer` /
#' `dt_cost_for_exhibits_cancer`. Uses data.table syntax and bare ggplot2
#' functions, so both packages are assumed to be attached.
#'
#' @param dt_exh data.table keyed by id_var, S_index_date_XX,
#'   S_sample_source_XX with DMG_died_within_365d, num_days_lived,
#'   prob_for_report and UTL_f365d_total_cost. The row-count check below
#'   assumes one row per key.
#' @param dt_cst data.table of cost rows with cost_date, actual_cost,
#'   top_main_cat and the same three key columns.
#' @param sample_n Either "cancer" or "all"; selects the sample label.
#' @param hosp_a,hosp_b Intensity labels used in the output file names.
#' @param hosp_a_name,hosp_b_name In-plot labels for the two inpatient lines.
#' @param loc Numeric vector of 9 label y-positions. For panels All, Decedent
#'   and Survivor (in that order), it gives the positions of hosp_b_name,
#'   hosp_a_name and "All Other Services".
#' @return The value of `dev.off()` after the PDF device is closed.
fig_Average_Monthly_Spending_by_Type <- function(
    dt_exh = dt_for_exhibits_cancer,
    dt_cst = dt_cost_for_exhibits_cancer,
    sample_n = "cancer",
    hosp_a = "Low",
    hosp_b = "High",
    hosp_a_name = "Inpatient :\nLow Intensity",
    hosp_b_name = "Inpatient :\nHigh Intensity",
    loc = c(2900, 1150, 4200,
            2300, 6600, 4400,
            2500, 900, 4200)) {

  # Validate cheap arguments before the (potentially expensive) data work.
  if (sample_n == "cancer") {
    sample_full_name <- "Cancer Sample"
  } else if (sample_n == "all") {
    sample_full_name <- "General Population Sample"
  } else {
    stop("Bad definition of sample_n")
  }
  # One y-position per label: 3 labels x 3 panels.
  stopifnot(is.numeric(loc), length(loc) == 9L)

  id_cols <- c("id_var", "S_index_date_XX", "S_sample_source_XX",
               "DMG_died_within_365d", "num_days_lived", "prob_for_report")

  # Total cost per patient x top-level category, only on/after the index date.
  cost_by_cat <- dt_cst[cost_date >= S_index_date_XX,
                        .(sum_cost = sum(actual_cost, na.rm = TRUE)),
                        by = .(id_var, S_index_date_XX, S_sample_source_XX,
                               cat = top_main_cat)]
  # Full outer join so patients without any cost row are kept. Such
  # patients get cat = NA and a zero cost.
  temp_dt <- merge(
    x = cost_by_cat,
    y = dt_exh[, .(id_var, S_index_date_XX, S_sample_source_XX,
                   DMG_died_within_365d, num_days_lived, prob_for_report)],
    by = c("id_var", "S_index_date_XX", "S_sample_source_XX"),
    all.x = TRUE,
    all.y = TRUE
  )
  temp_dt[is.na(sum_cost), sum_cost := 0]
  # Every cost row must have matched a patient in dt_exh.
  if (anyNA(temp_dt$num_days_lived)) {
    stop("Problem in the above merge!")
  }

  # Reshape to wide and back to long so every patient has every category.
  temp_dt_wide <- data.table::dcast.data.table(
    data = temp_dt,
    formula = id_var + S_index_date_XX + S_sample_source_XX +
      DMG_died_within_365d + num_days_lived + prob_for_report ~ cat,
    value.var = "sum_cost",
    fill = 0
  )
  # An "NA" category column may only come from cost-free patients, so it
  # must be all zeros. Every element is tested, so that positive and
  # negative costs cannot cancel out in a sum.
  if ("NA" %in% names(temp_dt_wide)) {
    if (any(temp_dt_wide[["NA"]] != 0)) {
      stop("Problem in the above merge!")
    }
    temp_dt_wide[, `NA` := NULL]
  }
  stopifnot(!any(vapply(temp_dt_wide, anyNA, logical(1))))

  cost_columns <- base::setdiff(names(temp_dt_wide), id_cols)
  # With a factor category, the wide columns must be exactly its levels, in
  # order. identical() avoids the silent recycling of `==` on unequal lengths.
  if (is.factor(temp_dt$cat)) {
    stopifnot(identical(cost_columns, levels(temp_dt$cat)))
  }
  temp_dt_long <- data.table::melt.data.table(
    data = temp_dt_wide,
    id.vars = id_cols,
    measure.vars = cost_columns,
    variable.name = "cat",
    value.name = "cost"
  )
  stopifnot(nrow(temp_dt_long) == nrow(dt_exh) * length(cost_columns))
  stopifnot(!anyNA(temp_dt_long$cost))
  # Category totals must reproduce the patient-level 365-day total within 1%.
  total_cost <- sum(temp_dt_long$cost)
  stopifnot(
    abs(total_cost - sum(dt_exh$UTL_f365d_total_cost)) / total_cost < 0.01
  )

  # Floor prognosis into 0.1-wide bins, computed as floor(p / w) * w like
  # plyr::round_any(p, 0.1, floor). The small tolerance fixes floating-point
  # misbinning: 0.3 / 0.1 == 2.9999999999999996, so 0.3 would otherwise
  # land in the 0.2 bin.
  bin_width <- 0.1
  prob_bin <- function(p) floor(p / bin_width + 1e-9) * bin_width

  # Average monthly spending = total cost / total days lived, scaled to a
  # 31-day month. First computed without splitting by vital status...
  dt_for_by_topCat_ex_unif <- temp_dt_long[
    , .(sample = sample_full_name,
        ave_cost = (sum(cost) / sum(num_days_lived)) * 31,
        .N),
    by = .(cat, bins = prob_bin(prob_for_report))
  ][N >= min_obs_num]

  # ...and then split into decedents and survivors.
  dt_for_by_topCat_ex_split <- temp_dt_long[
    , .(sample = sample_full_name,
        ave_cost = (sum(cost) / sum(num_days_lived)) * 31,
        .N),
    by = .(cat,
           group = factor(ifelse(DMG_died_within_365d == "1",
                                 "Decedent", "Survivor"),
                          levels = c("Decedent", "Survivor")),
           bins = prob_bin(prob_for_report))
  ][N >= min_obs_num]

  panel_levels <- c("All", "Decedent", "Survivor")

  # In-plot line labels: same three labels in each panel, positioned by loc.
  labs_dt <- data.table::data.table(
    labels = rep(c(hosp_b_name, hosp_a_name, "All Other\nServices"), 3),
    group = rep(panel_levels, each = 3),
    x = rep(c(0.3, 0.7, 0.5), each = 3),
    y = loc
  )[, group := factor(group, levels = panel_levels)]

  # Every row already belongs to sample_full_name. The former filter on the
  # literal "Cancer Sample" emptied the plot whenever sample_n == "all".
  dt_for_plot <- rbind(
    dt_for_by_topCat_ex_split,
    dt_for_by_topCat_ex_unif[, group := "All"]
  )[, group := factor(group, levels = panel_levels)]

  out_stub <- paste0("Average_Monthly_Spending_by_Type_", hosp_a, "_", hosp_b)
  utils::write.csv(dt_for_plot, paste0(out_stub, ".csv"))

  plot_obj <- ggplot(data = dt_for_plot, aes(x = bins, y = ave_cost)) +
    # ggplot2 >= 3.4 prefers `linewidth`. `size` is kept so older ggplot2
    # versions still work.
    geom_line(aes(linetype = cat), size = 1.2) +
    geom_text(data = labs_dt,
              aes(label = labels, x = x, y = y),
              size = inplot_text_size) +
    facet_grid(~group) +
    scale_x_continuous(breaks = seq(0, 0.9, 0.2)) +
    scale_y_continuous(labels = scales::comma, limits = c(0, 8000)) +
    # NOTE(review): no colour aesthetic is mapped, so this scale has no
    # visible effect.
    scale_color_manual(guide = "none",
                       values = c("#41AB5D", "#F8766D", "#00BFC4")) +
    scale_linetype_manual(guide = "none",
                          name = "Spending\nCategory",
                          values = c("twodash", "dotted", "solid")) +
    labs(x = "Initial Prognosis (One-Year Mortality Risk)",
         y = "Average Monthly Spending (NIS)") +
    theme(legend.key.width = unit(5, "line"),
          axis.title.x = element_text(size = 16),
          axis.title.y = element_text(size = 16))

  grDevices::pdf(paste0(out_stub, ".pdf"), width = 10, height = 7.5)
  pdf_device <- grDevices::dev.cur()
  # If rendering fails, still close the PDF device so it does not stay open
  # and capture later graphics output. On success it is already closed below.
  on.exit(
    if (pdf_device %in% grDevices::dev.list()) grDevices::dev.off(pdf_device),
    add = TRUE
  )
  print(plot_obj)
  grDevices::dev.off(pdf_device)
}


